/***************************************************************************
* Copyright (c) Wolf Vollprecht, Johan Mabille and Sylvain Corlay          *
* Copyright (c) QuantStack                                                 *
*                                                                          *
* Distributed under the terms of the BSD 3-Clause License.                 *
*                                                                          *
* The full license is in the file LICENSE, distributed with this software. *
****************************************************************************/

// This file is generated from test/files/cppy_source/test_lstsq.cppy by preprocess.py!

#include <algorithm>

#include "gtest/gtest.h"
#include "xtensor/xarray.hpp"
#include "xtensor/xfixed.hpp"
#include "xtensor/xnoalias.hpp"
#include "xtensor/xstrided_view.hpp"
#include "xtensor/xtensor.hpp"
#include "xtensor/xview.hpp"

#include "xtensor-blas/xlinalg.hpp"

namespace xt
{
    using namespace xt::placeholders;

    /*py
    a = np.random.random((6, 3))
    b = np.ones((6))
    */
    TEST(xtest_extended, lstsq1)
    {
        // 6x3 coefficient matrix (np.random.random((6, 3))).
        const xarray<double> lhs = {
            {0.3745401188473625,0.9507143064099162,0.7319939418114051},
            {0.5986584841970366,0.1560186404424365,0.1559945203362026},
            {0.0580836121681995,0.8661761457749352,0.6011150117432088},
            {0.7080725777960455,0.0205844942958024,0.9699098521619943},
            {0.8324426408004217,0.2123391106782762,0.1818249672071006},
            {0.1834045098534338,0.3042422429595377,0.5247564316322378}
        };
        // 1-D right-hand side of six ones (np.ones((6))).
        const xarray<double> rhs = {1.,1.,1.,1.,1.,1.};

        // Reference output of np.linalg.lstsq(a, b), one variable per tuple slot.
        const xarray<double> want_x = {0.99525656797683  ,0.6379298291900684,0.416589303565964 };  // [0]
        const xarray<double> want_residuals = {0.3378625895661748};                                // [1]
        const int want_rank = 3;                                                                   // [2]
        const xarray<double> want_sv = {2.081504268698353,1.012756249516551,0.599044658280111};    // [3]

        // Solve once, then bind each tuple element to a readable name.
        const auto fit = xt::linalg::lstsq(lhs, rhs);
        const auto& got_x = std::get<0>(fit);
        const auto& got_residuals = std::get<1>(fit);
        const auto& got_rank = std::get<2>(fit);
        const auto& got_sv = std::get<3>(fit);

        EXPECT_TRUE(xt::allclose(got_x, want_x));
        EXPECT_TRUE(xt::allclose(got_residuals, want_residuals));
        EXPECT_EQ(got_rank, want_rank);
        EXPECT_TRUE(xt::allclose(got_sv, want_sv));
    }

    /*py
    a = np.random.random((3, 3))
    b = np.ones((3))
    */
    TEST(xtest_extended, lstsq20)
    {
        // 3x3 coefficient matrix (np.random.random((3, 3))).
        const xarray<double> lhs = {
            {0.4319450186421158,0.2912291401980419,0.6118528947223795},
            {0.1394938606520418,0.2921446485352182,0.3663618432936917},
            {0.4560699842170359,0.7851759613930136,0.1996737821583597}
        };
        // 1-D right-hand side of three ones (np.ones((3))).
        const xarray<double> rhs = {1.,1.,1.};

        // Reference output of np.linalg.lstsq(a, b), one variable per tuple slot.
        const xarray<double> want_x = {-1.655587220862159 , 1.7320451450169407, 1.9787446378934206};  // [0]
        // [1]: the reference fixture for the residuals is an empty initializer.
        const xarray<double> want_residuals = {};
        const int want_rank = 3;                                                                      // [2]
        const xarray<double> want_sv = {1.2339483753871052,0.4580824861786693,0.1291723342275802};    // [3]

        // Solve once, then bind each tuple element to a readable name.
        const auto fit = xt::linalg::lstsq(lhs, rhs);
        const auto& got_x = std::get<0>(fit);
        const auto& got_residuals = std::get<1>(fit);
        const auto& got_rank = std::get<2>(fit);
        const auto& got_sv = std::get<3>(fit);

        EXPECT_TRUE(xt::allclose(got_x, want_x));
        EXPECT_TRUE(xt::allclose(got_residuals, want_residuals));
        EXPECT_EQ(got_rank, want_rank);
        EXPECT_TRUE(xt::allclose(got_sv, want_sv));
    }

    /*py
    a = np.random.random((3, 3))
    b = np.ones((3, 3))
    */
    TEST(xtest_extended, lstsq21)
    {
        // 3x3 coefficient matrix (np.random.random((3, 3))).
        const xarray<double> lhs = {
            {0.5142344384136116,0.5924145688620425,0.0464504127199977},
            {0.6075448519014384,0.1705241236872915,0.0650515929852795},
            {0.9488855372533332,0.9656320330745594,0.8083973481164611}
        };
        // 2-D right-hand side: 3x3 block of ones (np.ones((3, 3))).
        const xarray<double> rhs = {
            {1.,1.,1.},
            {1.,1.,1.},
            {1.,1.,1.}
        };

        // Reference output of np.linalg.lstsq(a, b), one variable per tuple slot.
        // [0]: one solution column per right-hand-side column.
        const xarray<double> want_x = {
            { 1.6749237812267237, 1.6749237812267237, 1.6749237812267237},
            { 0.3213797243357512, 0.3213797243357512, 0.3213797243357512},
            {-1.1128753832544371,-1.1128753832544371,-1.1128753832544371}
        };
        // [1]: the reference fixture for the residuals is an empty initializer.
        const xarray<double> want_residuals = {};
        // [2]
        const int want_rank = 3;
        // [3]
        const xarray<double> want_sv = {1.8090476189892228,0.4005423925178662,0.2705890168670333};

        // Solve once, then bind each tuple element to a readable name.
        const auto fit = xt::linalg::lstsq(lhs, rhs);
        const auto& got_x = std::get<0>(fit);
        const auto& got_residuals = std::get<1>(fit);
        const auto& got_rank = std::get<2>(fit);
        const auto& got_sv = std::get<3>(fit);

        EXPECT_TRUE(xt::allclose(got_x, want_x));
        EXPECT_TRUE(xt::allclose(got_residuals, want_residuals));
        EXPECT_EQ(got_rank, want_rank);
        EXPECT_TRUE(xt::allclose(got_sv, want_sv));
    }

    /*py
    a = np.random.random((2, 5))
    b = np.ones((2))
    */
    TEST(xtest_extended, lstsq3)
    {
        // 2x5 coefficient matrix (np.random.random((2, 5))): fewer rows than columns.
        const xarray<double> lhs = {
            {0.3046137691733707,0.0976721140063839,0.6842330265121569,
             0.4401524937396013,0.1220382348447788},
            {0.4951769101112702,0.0343885211152184,0.9093204020787821,
             0.2587799816000169,0.662522284353982 }
        };
        // 1-D right-hand side of two ones (np.ones((2))).
        const xarray<double> rhs = {1.,1.};

        // Reference output of np.linalg.lstsq(a, b), one variable per tuple slot.
        // [0]
        const xarray<double> want_x = { 0.3137661125421979, 0.183749537801855 , 0.8404557593671863,
                                        0.7586648365305537,-0.1845363594995904};
        // [1]: the reference fixture for the residuals is an empty initializer.
        const xarray<double> want_residuals = {};
        // [2]
        const int want_rank = 2;
        // [3]
        const xarray<double> want_sv = {1.4931292414997537,0.3589512974668556};

        // Solve once, then bind each tuple element to a readable name.
        const auto fit = xt::linalg::lstsq(lhs, rhs);
        const auto& got_x = std::get<0>(fit);
        const auto& got_residuals = std::get<1>(fit);
        const auto& got_rank = std::get<2>(fit);
        const auto& got_sv = std::get<3>(fit);

        EXPECT_TRUE(xt::allclose(got_x, want_x));
        EXPECT_TRUE(xt::allclose(got_residuals, want_residuals));
        EXPECT_EQ(got_rank, want_rank);
        EXPECT_TRUE(xt::allclose(got_sv, want_sv));
    }

    /*py
    a = np.random.random((2, 5))
    b = np.ones((2, 10))
    */
    TEST(xtest_extended, lstsq4)
    {
        // 2x5 coefficient matrix (np.random.random((2, 5))): fewer rows than columns.
        const xarray<double> lhs = {
            {0.311711076089411 ,0.5200680211778108,0.5467102793432796,
             0.184854455525527 ,0.9695846277645586},
            {0.7751328233611146,0.9394989415641891,0.8948273504276488,
             0.5978999788110851,0.9218742350231168}
        };
        // 2-D right-hand side: 2x10 block of ones (np.ones((2, 10))).
        const xarray<double> rhs = {
            {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.},
            {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.}
        };

        // Reference output of np.linalg.lstsq(a, b), one variable per tuple slot.
        // [0]: 5x10, one solution column per right-hand-side column.
        const xarray<double> want_x = {
            {-0.0723929848964098,-0.0723929848964098,-0.0723929848964098,
             -0.0723929848964098,-0.0723929848964098,-0.0723929848964098,
             -0.0723929848964098,-0.0723929848964098,-0.0723929848964098,
             -0.0723929848964098},
            { 0.1423971374668718, 0.1423971374668718, 0.1423971374668718,
              0.1423971374668718, 0.1423971374668718, 0.1423971374668718,
              0.1423971374668718, 0.1423971374668718, 0.1423971374668718,
              0.1423971374668718},
            { 0.2187317829605842, 0.2187317829605842, 0.2187317829605842,
              0.2187317829605842, 0.2187317829605842, 0.2187317829605842,
              0.2187317829605842, 0.2187317829605842, 0.2187317829605842,
              0.2187317829605842},
            {-0.1457627271119433,-0.1457627271119433,-0.1457627271119433,
             -0.1457627271119433,-0.1457627271119433,-0.1457627271119433,
             -0.1457627271119433,-0.1457627271119433,-0.1457627271119433,
             -0.1457627271119433},
            { 0.882719722037499 , 0.882719722037499 , 0.882719722037499 ,
              0.882719722037499 , 0.882719722037499 , 0.882719722037499 ,
              0.882719722037499 , 0.882719722037499 , 0.882719722037499 ,
              0.882719722037499 }
        };
        // [1]: the reference fixture for the residuals is an empty initializer.
        const xarray<double> want_residuals = {};
        // [2]
        const int want_rank = 2;
        // [3]
        const xarray<double> want_sv = {2.23042850951828  ,0.3968910268428817};

        // Solve once, then bind each tuple element to a readable name.
        const auto fit = xt::linalg::lstsq(lhs, rhs);
        const auto& got_x = std::get<0>(fit);
        const auto& got_residuals = std::get<1>(fit);
        const auto& got_rank = std::get<2>(fit);
        const auto& got_sv = std::get<3>(fit);

        EXPECT_TRUE(xt::allclose(got_x, want_x));
        EXPECT_TRUE(xt::allclose(got_residuals, want_residuals));
        EXPECT_EQ(got_rank, want_rank);
        EXPECT_TRUE(xt::allclose(got_sv, want_sv));
    }

    /*py
    a = np.random.random((10, 5))
    b = np.ones((10, 20))
    */
    TEST(xtest_extended, lstsq5)
    {
        // 10x5 coefficient matrix (np.random.random((10, 5))).
        const xarray<double> lhs = {
            {0.0884925020519195,0.1959828624191452,0.0452272889105381,
             0.3253303307632643,0.388677289689482 },
            {0.2713490317738959,0.8287375091519293,0.3567533266935893,
             0.2809345096873808,0.5426960831582485},
            {0.1409242249747626,0.8021969807540397,0.0745506436797708,
             0.9868869366005173,0.7722447692966574},
            {0.1987156815341724,0.0055221171236024,0.8154614284548342,
             0.7068573438476171,0.7290071680409873},
            {0.7712703466859457,0.0740446517340904,0.3584657285442726,
             0.1158690595251297,0.8631034258755935},
            {0.6232981268275579,0.3308980248526492,0.0635583502860236,
             0.3109823217156622,0.325183322026747 },
            {0.7296061783380641,0.6375574713552131,0.8872127425763265,
             0.4722149251619493,0.1195942459383017},
            {0.713244787222995 ,0.7607850486168974,0.5612771975694962,
             0.770967179954561 ,0.4937955963643907},
            {0.5227328293819941,0.4275410183585496,0.0254191267440952,
             0.1078914269933045,0.0314291856867343},
            {0.6364104112637804,0.3143559810763267,0.5085706911647028,
             0.907566473926093 ,0.2492922291488749}
        };
        // 2-D right-hand side: 10x20 block of ones (np.ones((10, 20))).
        const xarray<double> rhs = {
            {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.},
            {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.},
            {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.},
            {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.},
            {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.},
            {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.},
            {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.},
            {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.},
            {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.},
            {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.}
        };

        // Reference output of np.linalg.lstsq(a, b), one variable per tuple slot.
        // [0]: 5x20, one solution column per right-hand-side column. The last-digit
        // differences between some columns are part of the recorded reference data.
        const xarray<double> want_x = {
            { 0.7695214798127482, 0.7695214798127482, 0.7695214798127482,
              0.7695214798127482, 0.7695214798127482, 0.7695214798127482,
              0.7695214798127482, 0.7695214798127482, 0.7695214798127482,
              0.7695214798127482, 0.7695214798127482, 0.7695214798127482,
              0.7695214798127482, 0.7695214798127482, 0.7695214798127482,
              0.7695214798127482, 0.7695214798127483, 0.7695214798127483,
              0.7695214798127483, 0.7695214798127483},
            { 0.3603784058338763, 0.3603784058338763, 0.3603784058338763,
              0.3603784058338763, 0.3603784058338763, 0.3603784058338763,
              0.3603784058338763, 0.3603784058338763, 0.3603784058338763,
              0.3603784058338763, 0.3603784058338763, 0.3603784058338763,
              0.3603784058338763, 0.3603784058338763, 0.3603784058338763,
              0.3603784058338763, 0.3603784058338762, 0.3603784058338762,
              0.3603784058338763, 0.3603784058338763},
            {-0.0288908468951092,-0.0288908468951092,-0.0288908468951092,
             -0.0288908468951092,-0.0288908468951092,-0.0288908468951092,
             -0.0288908468951092,-0.0288908468951092,-0.0288908468951092,
             -0.0288908468951092,-0.0288908468951092,-0.0288908468951092,
             -0.0288908468951092,-0.0288908468951092,-0.0288908468951092,
             -0.0288908468951092,-0.0288908468951093,-0.0288908468951093,
             -0.0288908468951092,-0.0288908468951092},
            { 0.2739420182164651, 0.2739420182164651, 0.2739420182164651,
              0.2739420182164651, 0.2739420182164651, 0.2739420182164651,
              0.2739420182164651, 0.2739420182164651, 0.2739420182164651,
              0.2739420182164651, 0.2739420182164651, 0.2739420182164651,
              0.2739420182164651, 0.2739420182164651, 0.2739420182164651,
              0.2739420182164651, 0.2739420182164651, 0.2739420182164651,
              0.2739420182164652, 0.2739420182164652},
            { 0.6381721647626307, 0.6381721647626307, 0.6381721647626307,
              0.6381721647626307, 0.6381721647626307, 0.6381721647626307,
              0.6381721647626307, 0.6381721647626307, 0.6381721647626307,
              0.6381721647626307, 0.6381721647626307, 0.6381721647626307,
              0.6381721647626307, 0.6381721647626307, 0.6381721647626307,
              0.6381721647626307, 0.6381721647626307, 0.6381721647626307,
              0.6381721647626307, 0.6381721647626307}
        };
        // [1]: one residual entry per right-hand-side column (20 values).
        const xarray<double> want_residuals = {
            0.6683875034141331,0.6683875034141331,0.6683875034141331,
            0.6683875034141331,0.6683875034141331,0.6683875034141331,
            0.6683875034141331,0.6683875034141331,0.6683875034141331,
            0.6683875034141331,0.6683875034141331,0.6683875034141331,
            0.6683875034141331,0.6683875034141331,0.6683875034141331,
            0.6683875034141331,0.6683875034141331,0.6683875034141331,
            0.6683875034141331,0.6683875034141331
        };
        // [2]
        const int want_rank = 5;
        // [3]
        const xarray<double> want_sv = {3.317877520855451 ,1.0262463009257718,0.9696565206896536,
                                        0.8384020117545181,0.5915006407947916};

        // Solve once, then bind each tuple element to a readable name.
        const auto fit = xt::linalg::lstsq(lhs, rhs);
        const auto& got_x = std::get<0>(fit);
        const auto& got_residuals = std::get<1>(fit);
        const auto& got_rank = std::get<2>(fit);
        const auto& got_sv = std::get<3>(fit);

        EXPECT_TRUE(xt::allclose(got_x, want_x));
        EXPECT_TRUE(xt::allclose(got_residuals, want_residuals));
        EXPECT_EQ(got_rank, want_rank);
        EXPECT_TRUE(xt::allclose(got_sv, want_sv));
    }

    /*py
    a = np.array([[0., 1.]])
    b = np.array([1.])
    */
    TEST(xtest_extended, lstsq6)
    {
        // 1x2 coefficient matrix (np.array([[0., 1.]])).
        const xarray<double> lhs = {{0.,1.}};
        // Single-element 1-D right-hand side (np.array([1.])).
        const xarray<double> rhs = {1.};

        // Reference output of np.linalg.lstsq(a, b), one variable per tuple slot.
        const xarray<double> want_x = {0.,1.};     // [0]
        // [1]: the reference fixture for the residuals is an empty initializer.
        const xarray<double> want_residuals = {};
        const int want_rank = 1;                   // [2]
        const xarray<double> want_sv = {1.};       // [3]

        // Solve once, then bind each tuple element to a readable name.
        const auto fit = xt::linalg::lstsq(lhs, rhs);
        const auto& got_x = std::get<0>(fit);
        const auto& got_residuals = std::get<1>(fit);
        const auto& got_rank = std::get<2>(fit);
        const auto& got_sv = std::get<3>(fit);

        EXPECT_TRUE(xt::allclose(got_x, want_x));
        EXPECT_TRUE(xt::allclose(got_residuals, want_residuals));
        EXPECT_EQ(got_rank, want_rank);
        EXPECT_TRUE(xt::allclose(got_sv, want_sv));
    }

    /*py
    a = np.array([[1.], [1.]])
    b = np.array([1., 1.])
    */
    TEST(xtest_extended, lstsq7)
    {
        // 2x1 coefficient matrix of ones (np.array([[1.], [1.]])).
        // The generated "py_a" literal, i.e.
        //     xarray<double> py_a = {{1.},
        //                            {1.}};
        // is deliberately not used: that nested initializer list is an ambiguous
        // conversion, so the matrix is built with xt::ones instead.
        const xarray<double> lhs = xt::ones<double>({2, 1});
        // 1-D right-hand side of two ones (np.array([1., 1.])).
        const xarray<double> rhs = {1.,1.};

        // Reference output of np.linalg.lstsq(a, b), one variable per tuple slot.
        const xarray<double> want_x = {0.9999999999999997};              // [0]
        const xarray<double> want_residuals = {2.2508083912556065e-33};  // [1]
        const int want_rank = 1;                                         // [2]
        const xarray<double> want_sv = {1.4142135623730951};             // [3]

        // Solve once, then bind each tuple element to a readable name.
        const auto fit = xt::linalg::lstsq(lhs, rhs);
        const auto& got_x = std::get<0>(fit);
        const auto& got_residuals = std::get<1>(fit);
        const auto& got_rank = std::get<2>(fit);
        const auto& got_sv = std::get<3>(fit);

        EXPECT_TRUE(xt::allclose(got_x, want_x));
        EXPECT_TRUE(xt::allclose(got_residuals, want_residuals));
        EXPECT_EQ(got_rank, want_rank);
        EXPECT_TRUE(xt::allclose(got_sv, want_sv));
    }


}
